{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 文本情感分类问题\n",
    "- 机器学习方法 TFIDF+机器学习分类算法\n",
    "- 深度学习方法 TextCNN TextRNN **预训练的模型**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预训练的模型有哪些？\n",
    "- bert\n",
    "- albert\n",
    "- xlnet\n",
    "- robert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-rc1\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.backend as K\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "import os\n",
    "from transformers import *\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train shape = (99913, 7)\n",
      "test shape = (10000, 6)\n"
     ]
    }
   ],
   "source": [
    "TRAIN_PATH = './data/train_dataset/'\n",
    "TEST_PATH = './data/test_dataset/'\n",
    "BERT_PATH = './bert_base_chinese/'\n",
    "MAX_SEQUENCE_LENGTH = 140\n",
    "input_categories = '微博中文内容'\n",
    "output_categories = '情感倾向'\n",
    "\n",
    "df_train = pd.read_csv(TRAIN_PATH+'nCoV_100k_train.labled.csv',engine ='python')\n",
    "df_train = df_train[df_train[output_categories].isin(['-1','0','1'])]\n",
    "df_test = pd.read_csv(TEST_PATH+'nCov_10k_test.csv',engine ='python')\n",
    "df_sub = pd.read_csv(TEST_PATH+'submit_example.csv')\n",
    "print('train shape =', df_train.shape)\n",
    "print('test shape =', df_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 特征转化\n",
    "<img src=\"https://imgkr.cn-bj.ufileos.com/341d1c83-45cf-4e9b-a656-a547cd0f2c67.png\" width=\"50%\" height=\"50%\" />\n",
    "\n",
    "transformer包\n",
    "- https://huggingface.co/transformers/v2.5.0/model_doc/bert.html\n",
    "- tokenizer.encode_plus 参数详细见：https://github.com/huggingface/transformers/blob/72768b6b9c2083d9f2d075d80ef199a3eae881d8/src/transformers/tokenization_utils.py#L924 924行\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 3918, 2428, 722, 4706, 102],\n",
       " 'token_type_ids': [0, 0, 0, 0, 0, 0],\n",
       " 'attention_mask': [1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.encode_plus(\"深度之眼\",\n",
    "        add_special_tokens=True,\n",
    "        max_length=20,\n",
    "        truncation_strategy= 'longest_first')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _convert_to_transformer_inputs(instance, tokenizer, max_sequence_length):\n",
    "    \"\"\"Converts tokenized input to ids, masks and segments for transformer (including bert)\"\"\"\n",
    "    \"\"\"默认返回input_ids,token_type_ids,attention_mask\"\"\"\n",
    "    inputs = tokenizer.encode_plus(instance,\n",
    "        add_special_tokens=True,\n",
    "        max_length=max_sequence_length,\n",
    "        truncation_strategy= 'longest_first')\n",
    "\n",
    "    input_ids =  inputs[\"input_ids\"]\n",
    "    input_masks = inputs[\"attention_mask\"]\n",
    "    input_segments = inputs[\"token_type_ids\"]\n",
    "    padding_length = max_sequence_length - len(input_ids)\n",
    "    #填充\n",
    "    padding_id = tokenizer.pad_token_id\n",
    "    input_ids = input_ids + ([padding_id] * padding_length)\n",
    "    input_masks = input_masks + ([0] * padding_length)\n",
    "    input_segments = input_segments + ([0] * padding_length)\n",
    "    return [input_ids, input_masks, input_segments]\n",
    "\n",
    "\n",
    "def compute_input_arrays(df, columns, tokenizer, max_sequence_length):\n",
    "    input_ids, input_masks, input_segments = [], [], []\n",
    "    for instance in tqdm(df[columns]):\n",
    "        \n",
    "        ids, masks, segments = _convert_to_transformer_inputs(str(instance), tokenizer, max_sequence_length)\n",
    "        \n",
    "        input_ids.append(ids)\n",
    "        input_masks.append(masks)\n",
    "        input_segments.append(segments)\n",
    "\n",
    "    return [np.asarray(input_ids, dtype=np.int32), \n",
    "            np.asarray(input_masks, dtype=np.int32), \n",
    "            np.asarray(input_segments, dtype=np.int32)\n",
    "           ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated\n",
      "100%|███████████████████████████████████████████████████████████████████████████| 99913/99913 [02:02<00:00, 818.74it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 816.24it/s]\n"
     ]
    }
   ],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained(BERT_PATH+'bert-base-chinese-vocab.txt')\n",
    "inputs = compute_input_arrays(df_train, input_categories, tokenizer, MAX_SEQUENCE_LENGTH)\n",
    "test_inputs = compute_input_arrays(df_test, input_categories, tokenizer, MAX_SEQUENCE_LENGTH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 标签类别转化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_output_arrays(df, columns):\n",
    "    return np.asarray(df[columns].astype(int) + 1)\n",
    "outputs = compute_output_arrays(df_train, output_categories)\n",
    "\n",
    "\n",
    "# from sklearn.preprocessing import LabelEncoder\n",
    "# from tensorflow.keras.utils import to_categorical\n",
    "# lb = LabelEncoder()\n",
    "# train_label = lb.fit_transform(train['class'].values)\n",
    "# train_label = to_categorical(train_label)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT模型\n",
    "\n",
    "https://huggingface.co/models\n",
    "\n",
    "\n",
    "<img src=\"https://imgkr.cn-bj.ufileos.com/9115ac01-f455-498b-8c38-9c4abb04046c.png\" width=\"50%\" height=\"50%\" />\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def create_model():\n",
    "    input_id = tf.keras.layers.Input((MAX_SEQUENCE_LENGTH,), dtype=tf.int32)    \n",
    "    input_mask = tf.keras.layers.Input((MAX_SEQUENCE_LENGTH,), dtype=tf.int32)    \n",
    "    input_atn = tf.keras.layers.Input((MAX_SEQUENCE_LENGTH,), dtype=tf.int32)    \n",
    "    config = BertConfig.from_pretrained(BERT_PATH + 'bert-base-chinese-config.json', output_hidden_states=True)\n",
    "    bert_model = TFBertModel.from_pretrained(BERT_PATH+'bert-base-chinese-tf_model.h5', config=config)\n",
    "    sequence_output, pooler_output, hidden_states = bert_model(input_id, attention_mask=input_mask, token_type_ids=input_atn)\n",
    "    #(bs,140,768)(bs,768)\n",
    "    x = tf.keras.layers.GlobalAveragePooling1D()(sequence_output)    \n",
    "    x = tf.keras.layers.Dropout(0.15)(x)\n",
    "    x = tf.keras.layers.Dense(3, activation='softmax')(x)\n",
    "\n",
    "    model = tf.keras.models.Model(inputs=[input_id, input_mask, input_atn], outputs=x)\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['acc', 'mae'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#11G 1080ti\n",
    "#colab 16G 11G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 79929 samples, validate on 19984 samples\n",
      "Epoch 1/2\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "79929/79929 [==============================] - 1495s 19ms/sample - loss: 0.6041 - acc: 0.7354 - mae: 0.2347 - val_loss: 0.5676 - val_acc: 0.7512 - val_mae: 0.2265\n",
      "Epoch 2/2\n",
      "79929/79929 [==============================] - 1465s 18ms/sample - loss: 0.5197 - acc: 0.7751 - mae: 0.2041 - val_loss: 0.5948 - val_acc: 0.7307 - val_mae: 0.2236\n",
      "Train on 79929 samples, validate on 19984 samples\n",
      "Epoch 1/2\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "79929/79929 [==============================] - 1496s 19ms/sample - loss: 0.6054 - acc: 0.7376 - mae: 0.2315 - val_loss: 0.6126 - val_acc: 0.7202 - val_mae: 0.2300\n",
      "Epoch 2/2\n",
      "79929/79929 [==============================] - 1470s 18ms/sample - loss: 0.5115 - acc: 0.7796 - mae: 0.1994 - val_loss: 0.6417 - val_acc: 0.7081 - val_mae: 0.2347\n",
      "Train on 79931 samples, validate on 19982 samples\n",
      "Epoch 1/2\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "79931/79931 [==============================] - 1495s 19ms/sample - loss: 0.5866 - acc: 0.7453 - mae: 0.2248 - val_loss: 0.6862 - val_acc: 0.6972 - val_mae: 0.2497\n",
      "Epoch 2/2\n",
      "79931/79931 [==============================] - 1464s 18ms/sample - loss: 0.4890 - acc: 0.7901 - mae: 0.1910 - val_loss: 0.7305 - val_acc: 0.6977 - val_mae: 0.2414\n",
      "Train on 79931 samples, validate on 19982 samples\n",
      "Epoch 1/2\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "79931/79931 [==============================] - 1503s 19ms/sample - loss: 0.6012 - acc: 0.7370 - mae: 0.2329 - val_loss: 0.6116 - val_acc: 0.7384 - val_mae: 0.2147\n",
      "Epoch 2/2\n",
      "79931/79931 [==============================] - 1468s 18ms/sample - loss: 0.5054 - acc: 0.7816 - mae: 0.1984 - val_loss: 0.6154 - val_acc: 0.7369 - val_mae: 0.2182\n",
      "Train on 79932 samples, validate on 19981 samples\n",
      "Epoch 1/2\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "79932/79932 [==============================] - 1500s 19ms/sample - loss: 0.5901 - acc: 0.7421 - mae: 0.2291 - val_loss: 0.6313 - val_acc: 0.7343 - val_mae: 0.2193\n",
      "Epoch 2/2\n",
      "79932/79932 [==============================] - 1469s 18ms/sample - loss: 0.4958 - acc: 0.7865 - mae: 0.1946 - val_loss: 0.6538 - val_acc: 0.7230 - val_mae: 0.2213\n"
     ]
    }
   ],
   "source": [
    "gkf = StratifiedKFold(n_splits=5).split(X=df_train[input_categories].fillna('-1'), y=df_train[output_categories].fillna('-1'))\n",
    "\n",
    "valid_preds = []\n",
    "test_preds = []\n",
    "for fold, (train_idx, valid_idx) in enumerate(gkf):\n",
    "    train_inputs = [inputs[i][train_idx] for i in range(len(inputs))]\n",
    "    train_outputs = to_categorical(outputs[train_idx])\n",
    "\n",
    "    valid_inputs = [inputs[i][valid_idx] for i in range(len(inputs))]\n",
    "    valid_outputs = to_categorical(outputs[valid_idx])\n",
    "\n",
    "    K.clear_session()\n",
    "    model = create_model()\n",
    "    model.fit(train_inputs, train_outputs, validation_data= [valid_inputs, valid_outputs], epochs=2, batch_size=32)\n",
    "    # model.save_weights(f'bert-{fold}.h5')\n",
    "    valid_preds.append(model.predict(valid_inputs))\n",
    "    test_preds.append(model.predict(test_inputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>微博id</th>\n",
       "      <th>微博发布时间</th>\n",
       "      <th>发布人账号</th>\n",
       "      <th>微博中文内容</th>\n",
       "      <th>微博图片</th>\n",
       "      <th>微博视频</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4456068992182160</td>\n",
       "      <td>01月01日 23:38</td>\n",
       "      <td>-精緻的豬豬女戰士-</td>\n",
       "      <td>#你好2020#新年第一天元气满满的早起出门买早饭结果高估了自己抗冻能力回家成功冻发烧（大概...</td>\n",
       "      <td>['https://ww2.sinaimg.cn/thumb150/745aa591ly1g...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4456424178427250</td>\n",
       "      <td>01月02日 23:09</td>\n",
       "      <td>liujunyi88</td>\n",
       "      <td>大宝又感冒鼻塞咳嗽了，还有发烧。队友加班几天不回。感觉自己的情绪在家已然是随时引爆的状态。情...</td>\n",
       "      <td>[]</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4456797466940200</td>\n",
       "      <td>01月03日 23:53</td>\n",
       "      <td>ablsa</td>\n",
       "      <td>还要去输两天液，这天也太容易感冒发烧了，一定要多喝热水啊?</td>\n",
       "      <td>['https://ww3.sinaimg.cn/orj360/006fTidCly1gaj...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4456791021108920</td>\n",
       "      <td>01月03日 23:27</td>\n",
       "      <td>喵吃鱼干Lynn</td>\n",
       "      <td>我太难了别人怎么发烧都没事就我一检查甲型流感?</td>\n",
       "      <td>[]</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4457086404997440</td>\n",
       "      <td>01月04日 19:01</td>\n",
       "      <td>我的发小今年必脱单</td>\n",
       "      <td>果然是要病一场的喽回来第三天开始感冒今儿还发烧了喉咙眼睛都难受的一匹怎么样能不经意让我的毕设...</td>\n",
       "      <td>[]</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               微博id        微博发布时间       发布人账号  \\\n",
       "0  4456068992182160  01月01日 23:38  -精緻的豬豬女戰士-   \n",
       "1  4456424178427250  01月02日 23:09  liujunyi88   \n",
       "2  4456797466940200  01月03日 23:53       ablsa   \n",
       "3  4456791021108920  01月03日 23:27    喵吃鱼干Lynn   \n",
       "4  4457086404997440  01月04日 19:01   我的发小今年必脱单   \n",
       "\n",
       "                                              微博中文内容  \\\n",
       "0  #你好2020#新年第一天元气满满的早起出门买早饭结果高估了自己抗冻能力回家成功冻发烧（大概...   \n",
       "1  大宝又感冒鼻塞咳嗽了，还有发烧。队友加班几天不回。感觉自己的情绪在家已然是随时引爆的状态。情...   \n",
       "2                      还要去输两天液，这天也太容易感冒发烧了，一定要多喝热水啊?   \n",
       "3                            我太难了别人怎么发烧都没事就我一检查甲型流感?   \n",
       "4  果然是要病一场的喽回来第三天开始感冒今儿还发烧了喉咙眼睛都难受的一匹怎么样能不经意让我的毕设...   \n",
       "\n",
       "                                                微博图片 微博视频  \n",
       "0  ['https://ww2.sinaimg.cn/thumb150/745aa591ly1g...   []  \n",
       "1                                                 []   []  \n",
       "2  ['https://ww3.sinaimg.cn/orj360/006fTidCly1gaj...   []  \n",
       "3                                                 []   []  \n",
       "4                                                 []   []  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub = np.average(test_preds, axis=0)\n",
    "sub = np.argmax(sub,axis=1)\n",
    "# df_sub['y'] = sub-1\n",
    "# #df_sub['id'] = df_sub['id'].apply(lambda x: str(x))\n",
    "# df_sub.to_csv('test_sub.csv',index=False, encoding='utf-8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sub = df_test[['微博id']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>微博id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4456068992182160</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4456424178427250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4456797466940200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4456791021108920</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4457086404997440</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               微博id\n",
       "0  4456068992182160\n",
       "1  4456424178427250\n",
       "2  4456797466940200\n",
       "3  4456791021108920\n",
       "4  4457086404997440"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_sub.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\miniconda3\\envs\\python3\\lib\\site-packages\\ipykernel_launcher.py:1: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    }
   ],
   "source": [
    "df_sub['y'] = sub-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sub.columns=['id','y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4456068992182160</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4456424178427250</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4456797466940200</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4456791021108920</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4457086404997440</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 id  y\n",
       "0  4456068992182160  0\n",
       "1  4456424178427250 -1\n",
       "2  4456797466940200  0\n",
       "3  4456791021108920 -1\n",
       "4  4457086404997440 -1"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_sub.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sub.to_csv('test_sub.csv',index=False, encoding='utf-8')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 总结：\n",
    "\n",
    "- transformer包的了解\n",
    "- bert等预训练模型的原理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn.svm import LinearSVC\n",
    "import numpy as np\n",
    "\n",
    "train = pd.read_csv('./train.csv',sep='  ',engine='python',encoding='utf8')\n",
    "test = pd.read_csv('./test.csv',sep='  ',engine='python',encoding='utf8')\n",
    "print(train.shape)\n",
    "train = train.dropna()\n",
    "\n",
    "train = train.reset_index(drop=True)\n",
    "print(train.shape)\n",
    "label_unique = train['stance'].unique()\n",
    "nb_class = len(label_unique)\n",
    "\n",
    "label_map = {'FAVOR':2, 'AGAINST':1, 'NONE':0}\n",
    "label_map_r = {2:'FAVOR', 1:'AGAINST', 0:'NONE'}\n",
    "\n",
    "import jieba\n",
    "train['text_cut'] = train['text'].apply(lambda x:' '.join(jieba.cut(x)))\n",
    "test['text_cut'] = test['text'].apply(lambda x:' '.join(jieba.cut(x)))\n",
    "\n",
    "train_text = list(train['text_cut'].values)\n",
    "test_text = list(test['text_cut'].values)\n",
    "totle_text = train_text + test_text\n",
    "\n",
    "tf = TfidfVectorizer(min_df=0, ngram_range=(1,2),stop_words='english')\n",
    "tf.fit(totle_text)\n",
    "train['stance'] = train['stance'].map(label_map)\n",
    "X = tf.transform(train_text)\n",
    "y = train['stance'].values\n",
    "X_test = tf.transform(test_text)\n",
    "skf = StratifiedKFold(n_splits=5,random_state=42,shuffle=True)\n",
    "\n",
    "oof_train = np.zeros((len(train),nb_class))\n",
    "oof_test = np.zeros((len(test),nb_class))\n",
    "\n",
    "for idx,(tr_in,te_in) in enumerate(skf.split(X,y)):\n",
    "    X_train = X[tr_in]\n",
    "    X_valid = X[te_in]\n",
    "    y_train = y[tr_in]\n",
    "    y_valid = y[te_in]\n",
    "    #\n",
    "    # lr = LogisticRegression()\n",
    "\n",
    "    lr1 = MultinomialNB(alpha=.25)\n",
    "\n",
    "    # lr = LinearSVC()\n",
    "\n",
    "    lr1.fit(X_train,y_train)\n",
    "    # lr.fit(X_train,y_train)\n",
    "\n",
    "    y_pred = lr1.predict_proba(X_valid)\n",
    "    oof_train[te_in] = y_pred\n",
    "\n",
    "    oof_test = oof_test + lr1.predict_proba(X_test) / skf.n_splits\n",
    "\n",
    "# 使用包大人推荐的方法\n",
    "x1 = np.array(oof_train)\n",
    "y1 = np.array(y)\n",
    "from scipy import optimize\n",
    "def fun(x):\n",
    "    tmp = np.hstack([x[0] * x1[:, 0].reshape(-1, 1), x[1] * x1[:, 1].reshape(-1, 1), x[2] * x1[:, 2].reshape(-1, 1)])\n",
    "    return - accuracy_score(y1, np.argmax(tmp, axis=1))\n",
    "x0 = np.asarray((0,0,0))\n",
    "res = optimize.fmin_powell(fun, x0)\n",
    "\n",
    "xx_score = accuracy_score(y,np.argmax(oof_train,axis=1))\n",
    "print('原始score',xx_score)\n",
    "# bestWght = search_best(oof_train, y)\n",
    "xx_cv = accuracy_score(y,np.argmax(oof_train * res,axis=1))\n",
    "print('修正后的',xx_cv)\n",
    "\n",
    "result = test[['text']].copy()\n",
    "result['label'] = np.argmax(oof_test,axis=1)\n",
    "result['label2'] = np.argmax(oof_test*res,axis=1)\n",
    "# print(result)\n",
    "# result['id'] = test.index + 1\n",
    "result['label'] = result['label'].map(label_map_r)\n",
    "result['label2'] = result['label2'].map(label_map_r)\n",
    "\n",
    "result[['label']].to_csv('./lr_tfidf_{}.csv'.format(str(xx_score).split('.')[1]),header=None,)\n",
    "result[['label2']].to_csv('./lr_tfidf_{}.csv'.format(str(xx_cv).split('.')[1]),header=None,)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
